from einops import rearrange
import torch

# 假设有一个4维张量，形状为 (batch_size, height, width, channels)
tensor = torch.rand((2, 3, 4, 5))

# 使用 einops 中的 rearrange 函数重新排列张量的维度
tensor_rearranged = rearrange(tensor, 'b h w c -> b (h w) c')

print("原始张量：\n", tensor.shape)
print("\n重新排列后的张量：\n", tensor_rearranged.shape)
